# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import abc
import warnings

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs.common import EnvBase


class ModelBasedEnvBase(EnvBase):
    """Basic environment for model-based RL algorithms.

    Wrapper around the model of the MBRL algorithm.
    It is meant to give an env framework to a world model (including but not limited to observations, reward,
    done state and safety constraints models) and to behave as a classical environment.

    This is a base class for other environments and it should not be used directly: subclasses must implement
    :meth:`_reset` (abstract here) and define their specs, e.g. in ``__init__`` or via :meth:`set_specs_from_env`.

    Example:
        >>> import torch
        >>> from tensordict import TensorDict
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.data import Composite, Unbounded
        >>> class MyMBEnv(ModelBasedEnvBase):
        ...     def __init__(self, world_model, device="cpu", batch_size=None):
        ...         super().__init__(world_model, device=device, batch_size=batch_size)
        ...         self.observation_spec = Composite(
        ...             hidden_observation=Unbounded((4,))
        ...         )
        ...         self.state_spec = Composite(
        ...             hidden_observation=Unbounded((4,)),
        ...         )
        ...         self.action_spec = Unbounded((1,))
        ...         self.reward_spec = Unbounded((1,))
        ...
        ...     def _reset(self, tensordict: TensorDict) -> TensorDict:
        ...         tensordict = TensorDict(
        ...             batch_size=self.batch_size,
        ...             device=self.device,
        ...         )
        ...         tensordict = tensordict.update(self.state_spec.rand())
        ...         tensordict = tensordict.update(self.observation_spec.rand())
        ...         return tensordict
        >>> # This environment is used as follows:
        >>> import torch.nn as nn
        >>> from torchrl.modules import MLP, WorldModelWrapper
        >>> world_model = WorldModelWrapper(
        ...     TensorDictModule(
        ...         MLP(out_features=4, activation_class=nn.ReLU, activate_last_layer=True, depth=0),
        ...         in_keys=["hidden_observation", "action"],
        ...         out_keys=["hidden_observation"],
        ...     ),
        ...     TensorDictModule(
        ...         nn.Linear(4, 1),
        ...         in_keys=["hidden_observation"],
        ...         out_keys=["reward"],
        ...     ),
        ... )
        >>> env = MyMBEnv(world_model)
        >>> tensordict = env.rollout(max_steps=10)
        >>> print(tensordict)
        TensorDict(
            fields={
                action: Tensor(torch.Size([10, 1]), dtype=torch.float32),
                done: Tensor(torch.Size([10, 1]), dtype=torch.bool),
                hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32),
                next: LazyStackedTensorDict(
                    fields={
                        hidden_observation: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
                    batch_size=torch.Size([10]),
                    device=cpu,
                    is_shared=False),
                reward: Tensor(torch.Size([10, 1]), dtype=torch.float32)},
            batch_size=torch.Size([10]),
            device=cpu,
            is_shared=False)


    Properties:
        observation_spec (Composite): sampling spec of the observations;
        action_spec (TensorSpec): sampling spec of the actions;
        reward_spec (TensorSpec): sampling spec of the rewards;
        input_spec (Composite): sampling spec of the inputs;
        batch_size (torch.Size): batch_size to be used by the env. If not set, the env accepts tensordicts of all batch sizes.
        device (torch.device): device where the env input and output are expected to live

    Args:
        world_model (TensorDictModule): model that generates world states and their corresponding rewards.
            It is moved to the env ``device`` at construction time.
        params (List[torch.Tensor], optional): parameters of the world model. When provided, they are passed
            to the world model as ``params=`` at every step. Defaults to ``None``.
        buffers (List[torch.Tensor], optional): buffers of the world model, passed as ``buffers=`` together
            with ``params`` (only used when ``params`` is not ``None``). Defaults to ``None``.
        device (torch.device, optional): device where the env input and output are expected to live.
            Defaults to ``"cpu"``.
        batch_size (torch.Size, optional): number of environments contained in the instance.
            Defaults to ``None``.
        run_type_checks (bool, optional): whether to run type checks on the step of the env.
            Defaults to ``False``.

    Methods:
        step (TensorDict -> TensorDict): step in the environment
        reset (TensorDict, optional -> TensorDict): reset the environment
        set_seed (int -> int): sets the seed of the environment
        rand_step (TensorDict, optional -> TensorDict): random step given the action spec
        rollout (Callable, ... -> TensorDict): executes a rollout in the environment with the given policy (or random
            steps if no policy is provided)

    """

    def __init__(
        self,
        world_model: TensorDictModule,
        params: list[torch.Tensor] | None = None,
        buffers: list[torch.Tensor] | None = None,
        device: DEVICE_TYPING = "cpu",
        batch_size: torch.Size | None = None,
        run_type_checks: bool = False,
    ):
        super().__init__(
            device=device,
            batch_size=batch_size,
            run_type_checks=run_type_checks,
        )
        # The world model must live where the env inputs/outputs live.
        self.world_model = world_model.to(self.device)
        # Optional explicit parameters/buffers forwarded to the world model in ``_step``.
        self.world_model_params = params
        self.world_model_buffers = buffers

    @classmethod
    def __new__(cls, *args, **kwargs):
        # Every model-based env is created with ``_inplace_update=False`` and
        # ``_batch_locked=False``; the latter presumably lets the env accept
        # tensordicts of any batch size (see ``batch_size`` above) — confirm in EnvBase.
        # NOTE(review): ``__new__`` is already an implicit static method. With
        # ``@classmethod`` the class gets bound by the decorator and is also passed
        # again by the instance-creation machinery, so ``args[0]`` is the class itself.
        # This is presumably harmless because the parent ``__new__`` does not use its
        # positional args — confirm before removing the decorator.
        return super().__new__(
            cls, *args, _inplace_update=False, _batch_locked=False, **kwargs
        )

    def set_specs_from_env(self, env: EnvBase):
        """Sets the specs of the environment from the specs of the given environment.

        The input and output specs of ``env`` are cloned (``env`` itself is left untouched), moved to
        ``self.device`` when it is not ``None``, and installed on this env. ``empty_cache()`` is then called.

        Args:
            env (EnvBase): environment whose specs are copied.
        """
        device = self.device
        output_spec = env.output_spec.clone()
        input_spec = env.input_spec.clone()
        if device is not None:
            output_spec = output_spec.to(device)
            input_spec = input_spec.to(device)
        # Written straight into ``__dict__``, presumably to bypass the parent's
        # attribute/spec setters — confirm before switching to plain assignment.
        self.__dict__["_output_spec"] = output_spec
        self.__dict__["_input_spec"] = input_spec
        # Specs changed: drop anything cached from the previous ones (assumed purpose of ``empty_cache``).
        self.empty_cache()

    def _step(
        self,
        tensordict: TensorDict,
    ) -> TensorDict:
        """Advances the env by one step by running the world model.

        Args:
            tensordict (TensorDict): input read by the world model (typically the current state/observation
                and the action, as in the class example).

        Returns:
            A new tensordict restricted to the top-level observation, done and reward keys declared in the
            specs. Keys missing from the world-model output are skipped rather than raising.
        """
        # step method requires to be immutable: work on a shallow copy so the
        # caller's tensordict is not modified by the world model.
        tensordict_out = tensordict.clone(recurse=False)
        # Compute world state
        if self.world_model_params is not None:
            # NOTE(review): relies on the world model accepting ``params``/``buffers``
            # keyword arguments — confirm the installed tensordict still supports this.
            tensordict_out = self.world_model(
                tensordict_out,
                params=self.world_model_params,
                buffers=self.world_model_buffers,
            )
        else:
            tensordict_out = self.world_model(tensordict_out)
        # done can be missing, it will be filled by `step`
        tensordict_out = tensordict_out.select(
            *self.observation_spec.keys(),
            *self.full_done_spec.keys(),
            *self.full_reward_spec.keys(),
            strict=False,
        )
        return tensordict_out

    @abc.abstractmethod
    def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
        """Resets the environment; must be implemented by subclasses."""
        raise NotImplementedError

    def _set_seed(self, seed: int | None) -> None:
        """Does not seed anything: model-based envs only emit a warning here."""
        warnings.warn("Set seed isn't needed for model based environments")
